import ast
import operator as op
import sys  # 添加sys导入
import logging
from typing import Any, Dict, Callable, Union, List, Tuple
from dataclasses import dataclass
import pandas as pd
import astunparse
from .indicators import IndicatorService  # 引入IndicatorService
from src.support.log.logger import logger

@dataclass
class IndicatorFunction:
    """指标函数描述"""
    name: str
    func: Callable
    params: Dict[str, type]
    description: str

class RuleParser:
    """规则解析引擎核心类"""
    
    @staticmethod
    def validate_syntax(rule: str) -> Tuple[bool, str]:
        """验证规则语法（不依赖数据）
        Args:
            rule: 规则表达式字符串
        Returns:
            (验证结果, 错误信息)
        """
        try:
            if not rule.strip():
                return False, "规则不能为空"
            ast.parse(rule, mode='eval')
            return True, "语法正确"
        except SyntaxError as e:
            logging.error(f"规则语法错误: {str(e)}")
            return False, f"规则语法错误: {str(e)}"
        except Exception as e:
            logging.error(f"规则验证异常: {str(e)}")
            return False, f"规则验证异常: {str(e)}"
            
    """规则解析引擎核心类"""
    
    OPERATORS = {
        ast.Gt: op.gt, 
        ast.Lt: op.lt,
        ast.Eq: op.eq,
        ast.GtE: op.ge,
        ast.LtE: op.le,
        ast.And: op.and_,
        ast.BitAnd: op.and_,
        ast.Or: op.or_,
        ast.Not: op.not_,
        ast.Add: op.add,
        ast.Sub: op.sub,
        ast.Mult: op.mul,
        ast.Div: op.truediv,
        ast.FloorDiv: op.floordiv,
        ast.Mod: op.mod,
        ast.Pow: op.pow
    }
    
    def __init__(self, data_provider: pd.DataFrame, indicator_service: IndicatorService, portfolio_manager: Any = None):
        """初始化解析器
        Args:
            data_provider: 提供OHLCV等市场数据的DataFrame
            indicator_service: 指标计算服务
            portfolio_manager: 投资组合管理器，用于获取COST、POSITION等变量
        """
        self.data = data_provider
        self.indicator_service = indicator_service
        self.portfolio_manager = portfolio_manager
        # 注册支持的指标函数
        self._indicators = {
            'REF': IndicatorFunction(
                name='REF',
                func=self._ref,
                params={'expr': str, 'period': int},
                description='引用前n期数据: REF(expr, period)'
            ),
            'RSI': IndicatorFunction(
                name='RSI',
                func=lambda series, period: self.indicator_service.calculate_indicator('rsi', series, self.current_index, period),
                params={'series': pd.Series, 'period': int},
                description='相对强弱指数: RSI(series, period)'
            )
        }
        self.series_cache = {}  # 序列缓存字典
        self.value_cache = {}   # 值缓存字典
        self.current_index = 0  # 当前计算位置
        self.max_recursion_depth = 100  # 最大递归深度
        self.recursion_counter = 0     # 递归计数器
        self.cache_hits = 0            # 缓存命中统计
        self.cache_misses = 0          # 缓存未命中统计
    
    def evaluate_at(self, rule: str, index: int) -> bool:
        """在指定K线位置评估规则
        Args:
            rule: 规则表达式字符串
            index: 数据索引位置
        Returns:
            规则评估结果(bool)
        """
        self.current_index = index
        self.recursion_counter = 0  # 重置递归计数器
        result = self.parse(rule)
        if not isinstance(result, bool):
            return bool(result)
        return result
        
    def parse(self, rule: str, mode: str = 'rule') -> Union[bool, float]:
        """解析规则表达式
        Args:
            rule: 规则表达式字符串，如"(SMA(5) > SMA(20)) & (RSI(14) < 30)"
            mode: 解析模式 ('rule'返回bool, 'ref'返回原始数值)
        Returns:
            规则评估结果(bool)或原始数值(float)
        Raises:
            SyntaxError: 规则语法错误时抛出
            ValueError: 指标参数错误时抛出
            RecursionError: 递归深度超过限制时抛出
        """
        try:
            if not rule.strip():
                return False if mode == 'rule' else 0.0
            tree = ast.parse(rule, mode='eval')
            result = self._eval(tree.body)
            final_result = bool(result) if mode == 'rule' else result
            if mode == 'rule':
                pass
                # logger.info(f"[RULE_RESULT] {rule} = {final_result}")
            return final_result
        except RecursionError:
            raise RecursionError("递归深度超过限制，请简化规则表达式")
        except Exception as e:
            raise SyntaxError(f"规则解析失败: {str(e)}") from e
        
    def clear_cache(self):
        """清除序列缓存"""
        self.series_cache = {}
    
    def get_or_create_series(self, expr: str) -> pd.Series:
        """获取或创建指标序列"""
        if expr in self.series_cache:
            return self.series_cache[expr]
        
        # 解析表达式并计算序列
        tree = ast.parse(expr, mode='eval')
        series = self._eval(tree.body)
        
        if not isinstance(series, pd.Series):
            raise ValueError(f"表达式 '{expr}' 未返回序列")
            
        self.series_cache[expr] = series
        return series
    
    def _node_to_expr(self, node) -> str:
        """将AST节点转换为表达式字符串，生成更简洁的列名"""
        if isinstance(node, ast.Compare):
            # 对于比较运算，生成更简洁的表达式
            left = self._node_to_expr_simple(node.left)
            right = self._node_to_expr_simple(node.comparators[0])
            op = self._get_operator_symbol(node.ops[0])
            return f"{left} {op} {right}"
        elif isinstance(node, ast.BinOp):
            # 对于二元运算，生成更简洁的表达式
            left = self._node_to_expr_simple(node.left)
            right = self._node_to_expr_simple(node.right)
            op = self._get_operator_symbol(node.op)
            return f"{left} {op} {right}"
        elif isinstance(node, ast.UnaryOp):
            # 对于一元运算
            operand = self._node_to_expr_simple(node.operand)
            op = self._get_operator_symbol(node.op)
            return f"{op}{operand}"
        else:
            # 其他情况使用原始方法
            expr = astunparse.unparse(node).strip()
            # 处理函数调用参数间的多余空格
            if isinstance(node, ast.Call):
                expr = expr.replace(', ', ',')
            return expr
    
    def _node_to_expr_simple(self, node) -> str:
        """生成更简洁的表达式字符串，用于内部运算"""
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Constant):
            return str(node.value)
        elif isinstance(node, ast.BinOp):
            left = self._node_to_expr_simple(node.left)
            right = self._node_to_expr_simple(node.right)
            op = self._get_operator_symbol(node.op)
            
            # 简化括号逻辑 - 只在真正必要时添加括号
            left_needs_parens = self._needs_parentheses(node.left, node.op, is_left=True)
            right_needs_parens = self._needs_parentheses(node.right, node.op, is_left=False)
            
            if left_needs_parens:
                left = f"({left})"
            if right_needs_parens:
                right = f"({right})"
                
            return f"{left} {op} {right}"
        elif isinstance(node, ast.UnaryOp):
            operand = self._node_to_expr_simple(node.operand)
            op = self._get_operator_symbol(node.op)
            return f"{op}{operand}"
        elif isinstance(node, ast.Call):
            # 函数调用保持原样
            func_name = node.func.id
            args = [self._node_to_expr_simple(arg) for arg in node.args]
            return f"{func_name}({','.join(args)})"
        else:
            return astunparse.unparse(node).strip()
    
    def _needs_parentheses(self, child_node, parent_op, is_left=True) -> bool:
        """判断子节点是否需要括号
        Args:
            child_node: 子AST节点
            parent_op: 父节点的操作符
            is_left: 是否为左操作数
        Returns:
            是否需要添加括号
        """
        if not isinstance(child_node, ast.BinOp):
            return False
            
        # 获取操作符优先级
        op_precedence = {
            ast.Pow: 4,
            ast.Mult: 3, ast.Div: 3, ast.FloorDiv: 3, ast.Mod: 3,
            ast.Add: 2, ast.Sub: 2,
            ast.Gt: 1, ast.Lt: 1, ast.Eq: 1, ast.GtE: 1, ast.LtE: 1
        }
        
        child_op_precedence = op_precedence.get(type(child_node.op), 0)
        parent_op_precedence = op_precedence.get(type(parent_op), 0)
        
        # 如果子操作符优先级更低，需要括号
        if child_op_precedence < parent_op_precedence:
            return True
            
        # 对于相同优先级的操作符，需要处理结合性
        if child_op_precedence == parent_op_precedence:
            # 对于左结合的运算符，右操作数需要括号
            if not is_left and parent_op_precedence in [2, 3]:  # +-*/等
                return True
            # 对于幂运算，左操作数需要括号
            if is_left and isinstance(parent_op, ast.Pow):
                return True
                
        return False
    
    def _get_operator_symbol(self, op_node) -> str:
        """获取运算符的符号表示"""
        if isinstance(op_node, ast.Add):
            return "+"
        elif isinstance(op_node, ast.Sub):
            return "-"
        elif isinstance(op_node, ast.Mult):
            return "*"
        elif isinstance(op_node, ast.Div):
            return "/"
        elif isinstance(op_node, ast.FloorDiv):
            return "//"
        elif isinstance(op_node, ast.Mod):
            return "%"
        elif isinstance(op_node, ast.Pow):
            return "**"
        elif isinstance(op_node, ast.Gt):
            return ">"
        elif isinstance(op_node, ast.Lt):
            return "<"
        elif isinstance(op_node, ast.Eq):
            return "=="
        elif isinstance(op_node, ast.GtE):
            return ">="
        elif isinstance(op_node, ast.LtE):
            return "<="
        elif isinstance(op_node, ast.USub):
            return "-"
        elif isinstance(op_node, ast.UAdd):
            return "+"
        elif isinstance(op_node, ast.Not):
            return "not "
        else:
            return str(op_node)
            
    def _store_expression_result(self, node, result, bool_only=False):
        """存储表达式结果到data
        Args:
            node: AST节点
            result: 计算结果
            bool_only: 是否为bool表达式(需要符号替换)
        """
        # 不存储数字常量（如-5）作为列
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return
        
        expr_str = self._node_to_expr(node)
        
        # 生成列名 - 保持原始表达式格式
        col_name = expr_str
        
        # 特殊处理：为COST和POSITION等特殊变量创建单独的列（即使值为0或None）
        if isinstance(node, ast.Name) and node.id in ['COST', 'POSITION']:
            # 为特定变量创建单独的列
            var_name = node.id
            if var_name not in self.data.columns:
                self.data[var_name] = [float('nan')] * len(self.data)
                self.data.attrs[f"{var_name}_expr"] = var_name
            if 0 <= self.current_index < len(self.data):
                # 即使result为0或None也存储
                self.data.at[self.current_index, var_name] = result if result is not None else float('nan')
            return  # 不继续存储为表达式列
        
        # 检查列是否已存在（跳过数字常量）
        if col_name not in self.data.columns and not (isinstance(node, ast.Constant) and isinstance(node.value, (int, float))):
            # 根据表达式类型初始化列
            if bool_only:
                self.data[col_name] = [False] * len(self.data)
            else:
                self.data[col_name] = [float('nan')] * len(self.data)
            # 添加表达式注释
            self.data.attrs[f"{col_name}_expr"] = expr_str
        
        # 存储结果（跳过数字常量）
        if 0 <= self.current_index < len(self.data) and not (isinstance(node, ast.Constant) and isinstance(node.value, (int, float))):
            self.data.at[self.current_index, col_name] = bool(result) if bool_only else result

    def _eval(self, node):
        """递归评估AST节点"""
        if isinstance(node, ast.Compare):
            left = self._eval(node.left)
            right = self._eval(node.comparators[0])
            result = self.OPERATORS[type(node.ops[0])](left, right)
            # 存储比较运算结果(bool值)
            self._store_expression_result(node, result, bool_only=True)
            # 只存储有意义的表达式（避免存储数字常量和中间表达式）
            # 不存储比较运算的子表达式，只存储最终布尔结果
            return result
        elif isinstance(node, ast.BoolOp):
            return self._eval_bool_op(node)
        elif isinstance(node, ast.Call):
            return self._eval_function_call(node)
        elif isinstance(node, ast.Name):
            return self._eval_variable(node)
        elif isinstance(node, ast.BinOp):
            left = self._eval(node.left)
            right = self._eval(node.right)
            
            # 处理除零错误
            if isinstance(node.op, (ast.Div, ast.FloorDiv)) and right == 0:
                # 获取当前时间信息
                current_time = ""
                if hasattr(self, 'data') and 'combined_time' in self.data.columns and 0 <= self.current_index < len(self.data):
                    current_time = self.data['combined_time'].iloc[self.current_index]
                
                # 获取表达式字符串用于更智能的错误处理
                expr_str = self._node_to_expr(node)
                
                # 特殊处理：对于COST/POSITION表达式，当POSITION为0时返回0.0（不生成信号）
                if 'COST' in expr_str and 'POSITION' in expr_str:
                    return 0.0  # 当持仓为0时，返回0.0表示不生成信号
                
                # 打印除零错误信息
                return 0.0  # 其他除零情况返回0
                
            return self.OPERATORS[type(node.op)](left, right)
        elif isinstance(node, ast.Constant):
            try:
                return float(node.value)
            except (TypeError, ValueError):
                return 0.0
        elif isinstance(node, ast.UnaryOp):
            # 处理一元运算符（如 -5, +10 等）
            operand = self._eval(node.operand)
            if isinstance(node.op, ast.USub):  # 负号
                return -operand
            elif isinstance(node.op, ast.UAdd):  # 正号
                return +operand
            elif isinstance(node.op, ast.Not):  # 逻辑非
                return not operand
            elif isinstance(node.op, ast.Invert):  # 按位取反 ~
                return ~int(operand) if operand is not None else None
            else:
                raise ValueError(f"不支持的一元运算符: {type(node.op)}")
        else:
            raise ValueError(f"不支持的AST节点类型: {type(node)}")

    def _eval_bool_op(self, node) -> bool:
        """评估逻辑运算符"""
        values = [self._eval(v) for v in node.values]
        result = self.OPERATORS[type(node.op)](*values)
        # 存储布尔运算结果(bool值)
        self._store_expression_result(node, result, bool_only=True)
        # 只存储布尔运算结果，不存储子表达式（避免存储中间表达式）
        return result
    
    def _eval_variable(self, node) -> float:
        """评估变量(从数据源获取或从portfolio_manager获取)"""
        var_name = node.id
        
        # 处理投资组合相关变量
        if var_name == 'COST' and self.portfolio_manager:
            # 获取持仓总成本
            result = self.portfolio_manager.get_total_cost()
            # 存储COST变量的值到对应列
            self._store_expression_result(node, result)
            return result
        elif var_name == 'POSITION' and self.portfolio_manager:
            # 获取当前标的的持仓数量
            # 需要从数据中获取当前标的代码
            if not self.data.empty and 'code' in self.data.columns:
                current_symbol = self.data['code'].iloc[self.current_index]
                position = self.portfolio_manager.get_position(current_symbol)
                result = position.quantity if position else 0.0
            else:
                result = 0.0
            # 存储POSITION变量的值到对应列
            self._store_expression_result(node, result)
            return result
        
        # 处理数据列变量
        if var_name not in self.data.columns:
            raise ValueError(f"数据中不存在列: {var_name}")
        value = self.data[var_name].iloc[self.current_index]
        if pd.isna(value):
            return 0.0  # 空值处理
        return float(value)
    
    def _eval_function_call(self, node):
        """评估指标函数调用"""
        # 检查递归深度
        self.recursion_counter += 1
        if self.recursion_counter > self.max_recursion_depth:
            self.recursion_counter -= 1  # 恢复计数器
            raise RecursionError(f"递归深度超过限制 ({self.max_recursion_depth})")
            
        try:
            func_name = node.func.id
            
        finally:
            self.recursion_counter -= 1  # 确保计数器减少
        
        # 记录函数参数
        args_str = ", ".join([self._node_to_expr(arg) for arg in node.args])
        
        
        # 特殊处理SMA指标的计算细节
        period = None
        if func_name.upper() == 'SMA':
            period_val = self._eval(node.args[1]) if len(node.args) > 1 else 5
            period = int(period_val) if isinstance(period_val, (int, float)) else 5
            data_column = self._node_to_expr(node.args[0]).strip('\"\'')
            start_idx = max(0, int(self.current_index) - period + 1)
            end_idx = int(self.current_index) + 1
            window_data = self.data[data_column].iloc[start_idx:end_idx]
            # logger.debug(f"[DEBUG] SMA计算窗口数据({period}期): {window_data.values}")
            # logger.debug(f"[DEBUG] SMA计算窗口索引: {max(0, self.current_index-period+1)}:{self.current_index+1}")
        
        # 特殊处理REF函数（需要解析器状态）
        if func_name == 'REF':
            if func_name not in self._indicators:
                raise ValueError(f"不支持的指标函数: {func_name}")
                
            indicator = self._indicators[func_name]
            
            if len(node.args) != 2:
                raise ValueError("REF需要2个参数 (REF(expr, period))")
                
            expr_node = node.args[0]
            period_node = node.args[1]
            
            # 获取表达式字符串
            expr_str = self._node_to_expr(expr_node)
            period = self._eval(period_node)
            
            if not isinstance(period, (int, float)):
                raise ValueError("REF周期必须是数字")
                
            return indicator.func(expr_str, int(period))
        
        # 其他指标函数委托给IndicatorService
        # 从第一个参数获取数据列名
        data_column = self._node_to_expr(node.args[0]).strip()
        # 移除可能的引号（兼容字符串字面量）
        if data_column.startswith('"') and data_column.endswith('"'):
            data_column = data_column[1:-1]
        elif data_column.startswith("'") and data_column.endswith("'"):
            data_column = data_column[1:-1]
            
        if data_column not in self.data.columns:
            raise ValueError(f"数据中不存在列: {data_column}")
        
        # 获取数据序列
        series = self.data[data_column]
        remaining_args = node.args[1:]
        
        # 生成缓存键：函数名+数据列+参数+当前索引
        remaining_args_str = [self._node_to_expr(arg) for arg in remaining_args]
        args_list = [data_column] + remaining_args_str
        args_str = ",".join(args_list)
        cache_key = f"{func_name}({args_str})@{self.current_index}"
            
        if cache_key in self.value_cache:
            self.cache_hits += 1
            cached_value = float(self.value_cache[cache_key])
            # logger.info(f"[CACHE_HIT] {func_name}({args_str})={cached_value}")
            if func_name.upper() == 'SMA':
                period_val = self._eval(node.args[1]) if len(node.args) > 1 else 5
                period = int(period_val) if isinstance(period_val, (int, float)) else 5
                data_column = self._node_to_expr(node.args[0]).strip('\"\'')
                # logger.info(f"[SMA_RESULT] SMA({data_column},{period})={cached_value}")
            return cached_value
        
        # 计算并缓存结果
        self.cache_misses += 1
        
        # 验证指标参数（特别是周期类参数）
        for arg_node in remaining_args:
            arg_value = self._eval(arg_node)
            if not isinstance(arg_value, (int, float)) or arg_value <= 0:
                raise ValueError(
                    f"函数 {func_name} 的参数必须是正数: {arg_value}"
                )
                
        # 检查数据长度是否满足指标计算要求
        min_required = self._get_min_data_requirement(
            func_name,
            *[self._eval(arg) for arg in remaining_args]
        )
        if self.current_index < min_required:
            return 0.0
            
        # 委托给IndicatorService计算指标（传递数据序列和剩余参数）
        try:
            # 特别记录SMA指标的计算细节
            if func_name.upper() == 'SMA':
                period_val = self._eval(remaining_args[0]) if remaining_args else 5
                period = int(period_val) if isinstance(period_val, (int, float)) else 5
                start_idx = max(0, int(self.current_index) - period + 1)
                end_idx = int(self.current_index) + 1
                # logger.debug(
                #     f"SMA计算详情: 周期={period}, "
                #     f"数据范围={start_idx}:{end_idx}, "
                #     f"当前值={series.iloc[self.current_index]}"
                # )
            
            # 确保current_index是整数
            current_index = int(self.current_index) if not isinstance(self.current_index, int) else self.current_index
            # 确保current_index是整数
            current_index = int(self.current_index) if not isinstance(self.current_index, int) else self.current_index
            result = self.indicator_service.calculate_indicator(
                func_name,
                series,  # 传递具体数据序列而非整个DataFrame
                current_index,
                *[self._eval(arg) for arg in remaining_args]  # 评估所有参数
            )
            
            
            if func_name.upper() == 'SMA':
                pass
                # logger.debug(f"SMA计算详情: {data_column},{period}={result}")
        except AttributeError as e:
            logging.error(f"不支持的指标函数: {func_name}, 错误: {str(e)}")
            raise ValueError(f"不支持的指标函数: {func_name}") from e
        except Exception as e:
            logging.error(
                f"指标计算失败: {func_name}({args_str}), "
                f"错误: {str(e)}, 位置={self.current_index}"
            )
            raise
        
        # 使用统一的安全转换方法
        try:
            result_float = self._safe_convert_to_float(
                result, 
                f"函数 {func_name} 的返回值"
            )
        except ValueError as e:
            # 添加额外上下文信息后重新抛出
            raise ValueError(
                f"指标函数 {func_name} 值转换失败: {str(e)}"
            ) from e
        
        # 缓存并返回结果
        self.value_cache[cache_key] = result_float
        # 存储指标计算结果到engine.data
        col_name = f"{func_name}({args_str})"
        
        # 严格检查列是否存在（包括属性和注释）
        col_exists = (
            col_name in self.data.columns and 
            f"{col_name}_expr" in self.data.attrs
        )
        
        if not col_exists:
            # 初始化列并填充NaN
            self.data[col_name] = [float('nan')] * len(self.data)
            # 添加表达式注释
            self.data.attrs[f"{col_name}_expr"] = f"{func_name}({args_str})"
        
        # 确保当前索引有效
        if 0 <= self.current_index < len(self.data):
            self.data.at[self.current_index, col_name] = result_float
            
        else:
            logger.error(f"无效索引 {self.current_index} 无法存储指标 {col_name}")
        
        return result_float
    
        
    def _safe_convert_to_float(self, value: Any, context: str = "") -> float:
        """安全转换为浮点数，包含详细错误处理
        Args:
            value: 需要转换的值
            context: 错误上下文描述
        Returns:
            转换后的浮点数(处理NaN为0.0)
        Raises:
            ValueError: 转换失败时抛出
        """
        from typing import Any
        import numpy as np
        
        # 处理NaN/None值
        if pd.isna(value) or value is None:
            return 0.0
            
        # 处理布尔值
        if isinstance(value, bool):
            return float(value)
            
        # 处理数字类型
        if isinstance(value, (int, float)):
            return float(value)
            
        # 处理字符串
        if isinstance(value, str):
            try:
                return float(value)
            except ValueError:
                raise ValueError(f"字符串无法转换为浮点数: {value} ({context})")
        
        # 处理Series类型
        if isinstance(value, pd.Series):
            if self.current_index < len(value):
                value = value.iloc[self.current_index]
            else:
                value = value.iloc[-1]
            return self._safe_convert_to_float(value, context)
        
        # 处理numpy类型
        if np and isinstance(value, (np.number, np.bool_, np.generic)):
            return float(value.item())
        
        # 处理可转换为float的类型
        if hasattr(value, '__float__'):
            try:
                return float(value)
            except (TypeError, ValueError) as e:
                raise ValueError(
                    f"类型转换失败: {type(value)} -> float (值: {value}, 上下文: {context})"
                ) from e
            
        raise ValueError(
            f"不支持的类型转换: {type(value)} -> float (值: {value}, 上下文: {context})"
        )

    def _get_min_data_requirement(self, func_name: str, *args) -> int:
        """获取指标函数的最小数据要求
        Args:
            func_name: 指标函数名
            *args: 指标参数
        Returns:
            最小需要的数据长度
        """
        func_name = func_name.lower()
        try:
            if func_name == 'sma':
                return int(float(args[0])) if args else 1
            elif func_name == 'rsi':
                return int(float(args[0])) if args else 14
            elif func_name == 'macd':
                return max(
                    int(float(args[0])) if len(args)>0 else 12,
                    int(float(args[1])) if len(args)>1 else 26,
                    int(float(args[2])) if len(args)>2 else 9
                )
            return 1  # 默认最小长度
        except (ValueError, TypeError):
            return 1  # 参数转换失败时返回默认值

    def _ref(self, expr: str, period: int) -> float:
        """引用前period期的指标值（保留在RuleParser中）
        1. 计算原始指标并存储
        2. 计算REF指标并存储
        """
        # logger.debug(f"[REF] 开始计算REF(expr={expr}, period={period})")
        
        # 先计算并存储原始指标
        if "(" in expr and ")" in expr:  # 如果是指标表达式
            original_result = self.parse(expr, mode='ref')
            original_col = expr
            if original_col not in self.data.columns:
                self.data[original_col] = None
            self.data.at[self.current_index, original_col] = original_result
        
        # 检查递归深度
        self.recursion_counter += 1
        if self.recursion_counter > self.max_recursion_depth:
            raise RecursionError(f"递归深度超过限制 ({self.max_recursion_depth})")
            
        if period < 0:
            raise ValueError("周期必须是非负数")
            
        # 保存当前索引
        original_index = self.current_index
        
        # 计算目标位置
        period_int = int(period) if period is not None else 0
        target_index = max(0, min(int(original_index) - period_int, len(self.data)-1))
        target_index = int(target_index)  # 确保转换为整数
        # logger.debug(f"[REF] 目标索引位置: {target_index} (当前索引: {original_index}, 回溯周期: {period})")
        
        # 回溯到历史位置计算表达式
        self.current_index = target_index
        try:
            # 尝试从缓存获取
            cache_key = f"REF({expr},{period})@{original_index}"
            if cache_key in self.value_cache:
                return float(self.value_cache[cache_key])
                
            # 使用完整parse流程解析表达式（确保嵌套指标计算也能存储结果）
            result = self.parse(expr, mode='ref')
            
            # 处理结果并缓存
            result_numeric = self._safe_convert_to_float(
                result,
                f"REF表达式 '{expr}'"
            )
            
            # 确保嵌套指标计算结果已存储
            if "(" in expr and ")" in expr:  # 如果是指标表达式
                nested_col = f"REF({expr},{period})"
                if nested_col not in self.data.columns:
                    self.data[nested_col] = None
                self.data.at[original_index, nested_col] = result_numeric
            
            self.value_cache[cache_key] = result_numeric
            # logger.info(f"[REF_RESULT] REF({expr},{period})={result_numeric} (from index {target_index})")
            return result_numeric
        except Exception as e:
            raise ValueError(f"REF函数计算失败: {str(e)}") from e
        finally:
            # 恢复原始位置
            self.current_index = original_index
            self.recursion_counter -= 1  # 减少递归计数器
